/* Copyright 2023 The MediaPipe Authors.

Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at

    http://www.apache.org/licenses/LICENSE-2.0

Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
==============================================================================*/

#include <dlfcn.h>

#include <cmath>
#include <cstdint>
#include <cstring>
#include <map>
#include <string>
#include <utility>
#include <vector>

#include "absl/log/absl_log.h"
#include "absl/status/status.h"
#include "mediapipe/framework/api2/node.h"
#include "mediapipe/framework/api2/port.h"
#include "mediapipe/framework/calculator_context.h"
#include "mediapipe/framework/calculator_framework.h"
#include "mediapipe/framework/deps/file_helpers.h"
#include "mediapipe/framework/formats/image_frame.h"
#include "mediapipe/framework/formats/tensor.h"
#include "mediapipe/tasks/cc/vision/image_generator/diffuser/diffuser_gpu.h"
#include "mediapipe/tasks/cc/vision/image_generator/diffuser/stable_diffusion_iterate_calculator.pb.h"

namespace mediapipe {
namespace api2 {
namespace {

DiffuserPriorityHint ToDiffuserPriorityHint(
    StableDiffusionIterateCalculatorOptions::ClPriorityHint priority) {
  switch (priority) {
    case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_LOW:
      return kDiffuserPriorityHintLow;
    case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_NORMAL:
      return kDiffuserPriorityHintNormal;
    case StableDiffusionIterateCalculatorOptions::PRIORITY_HINT_HIGH:
      return kDiffuserPriorityHintHigh;
  }
  return kDiffuserPriorityHintNormal;
}

DiffuserModelType ToDiffuserModelType(
    StableDiffusionIterateCalculatorOptions::ModelType model_type) {
  switch (model_type) {
    case StableDiffusionIterateCalculatorOptions::DEFAULT:
    case StableDiffusionIterateCalculatorOptions::SD_1:
      return kDiffuserModelTypeSd1;
  }
  return kDiffuserModelTypeSd1;
}

}  // namespace

// Runs diffusion models including, but not limited to, Stable Diffusion & gLDM.
//
// Inputs:
//   PROMPT - std::string
//     The prompt used to generate the image.
//   STEPS - int
//     The number of steps to run the UNet.
//   ITERATION - int
//     The iteration of the current run.
//   PLUGIN_TENSORS - std::vector<mediapipe::Tensor> @Optional
//     The output tensor vector of the diffusion plugins model.
//   PLUGIN_STRENGTH - float @Optional
//     The strength of the plugin tensors.
//   SHOW_RESULT - bool @Optional
//     Whether to show the diffusion result at the current step, regardless
//     of what show_every_n_iteration is set to.
//
// Outputs:
//   IMAGE - mediapipe::ImageFrame
//     The image generated by the Stable Diffusion model from the input prompt.
//     The output image is in RGB format.
//
// Example:
// node {
//   calculator: "StableDiffusionIterateCalculator"
//   input_stream: "PROMPT:prompt"
//   input_stream: "STEPS:steps"
//   output_stream: "IMAGE:result"
//   options {
//     [mediapipe.StableDiffusionIterateCalculatorOptions.ext] {
//       base_seed: 0
//       model_type: SD_1
//     }
//   }
// }
class StableDiffusionIterateCalculator : public Node {
 public:
  static constexpr Input<std::string> kPromptIn{"PROMPT"};
  static constexpr Input<int> kStepsIn{"STEPS"};
  static constexpr Input<int>::Optional kIterationIn{"ITERATION"};
  static constexpr Input<int>::Optional kRandSeedIn{"RAND_SEED"};
  static constexpr SideInput<StableDiffusionIterateCalculatorOptions>::Optional
      kOptionsIn{"OPTIONS"};
  static constexpr Input<std::vector<Tensor>>::Optional kPlugInTensorsIn{
      "PLUGIN_TENSORS"};
  static constexpr Input<float>::Optional kPluginStrengthIn{"PLUGIN_STRENGTH"};
  static constexpr Input<bool>::Optional kShowResultIn{"SHOW_RESULT"};
  static constexpr Output<mediapipe::ImageFrame> kImageOut{"IMAGE"};
  MEDIAPIPE_NODE_CONTRACT(kPromptIn, kStepsIn, kIterationIn, kRandSeedIn,
                          kPlugInTensorsIn, kPluginStrengthIn, kShowResultIn,
                          kOptionsIn, kImageOut);

  ~StableDiffusionIterateCalculator() {
    if (context_) DiffuserDelete();
    if (handle_) dlclose(handle_);
  }

  static absl::Status UpdateContract(CalculatorContract* cc);

  absl::Status Open(CalculatorContext* cc) override;
  absl::Status Process(CalculatorContext* cc) override;

 private:
  std::vector<DiffuserPluginTensor> GetPluginTensors(
      CalculatorContext* cc) const {
    if (!kPlugInTensorsIn(cc).IsConnected()) return {};
    std::vector<DiffuserPluginTensor> diffuser_tensors;
    diffuser_tensors.reserve(kPlugInTensorsIn(cc)->size());
    for (const auto& mp_tensor : *kPlugInTensorsIn(cc)) {
      DiffuserPluginTensor diffuser_tensor;
      diffuser_tensor.shape[0] = mp_tensor.shape().dims[0];
      diffuser_tensor.shape[1] = mp_tensor.shape().dims[1];
      diffuser_tensor.shape[2] = mp_tensor.shape().dims[2];
      diffuser_tensor.shape[3] = mp_tensor.shape().dims[3];
      diffuser_tensor.data = mp_tensor.GetCpuReadView().buffer<float>();
      diffuser_tensors.push_back(diffuser_tensor);
    }
    return diffuser_tensors;
  }

  absl::Status LoadDiffuser() {
    handle_ = dlopen("libimagegenerator_gpu.so", RTLD_NOW | RTLD_LOCAL);
    RET_CHECK(handle_) << dlerror();
    create_ptr_ = reinterpret_cast<DiffuserContext* (*)(const DiffuserConfig*)>(
        dlsym(handle_, "DiffuserCreate"));
    RET_CHECK(create_ptr_) << dlerror();
    reset_ptr_ =
        reinterpret_cast<int (*)(DiffuserContext*, const char*, int, int, float,
                                 const void*)>(dlsym(handle_, "DiffuserReset"));
    RET_CHECK(reset_ptr_) << dlerror();
    iterate_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, int, int)>(
        dlsym(handle_, "DiffuserIterate"));
    RET_CHECK(iterate_ptr_) << dlerror();
    decode_ptr_ = reinterpret_cast<int (*)(DiffuserContext*, uint8_t*)>(
        dlsym(handle_, "DiffuserDecode"));
    RET_CHECK(decode_ptr_) << dlerror();
    delete_ptr_ = reinterpret_cast<void (*)(DiffuserContext*)>(
        dlsym(handle_, "DiffuserDelete"));
    RET_CHECK(delete_ptr_) << dlerror();
    return absl::OkStatus();
  }

  DiffuserContext* DiffuserCreate(const DiffuserConfig* a) {
    return (*create_ptr_)(a);
  }
  bool DiffuserReset(const char* a, int b, int c, float d,
                     const std::vector<DiffuserPluginTensor>* e) {
    return (*reset_ptr_)(context_, a, b, c, d, e);
  }
  bool DiffuserIterate(int a, int b) { return (*iterate_ptr_)(context_, a, b); }
  bool DiffuserDecode(uint8_t* a) { return (*decode_ptr_)(context_, a); }
  void DiffuserDelete() { (*delete_ptr_)(context_); }

  void* handle_ = nullptr;
  DiffuserContext* context_ = nullptr;
  DiffuserContext* (*create_ptr_)(const DiffuserConfig*);
  int (*reset_ptr_)(DiffuserContext*, const char*, int, int, float,
                    const void*);
  int (*iterate_ptr_)(DiffuserContext*, int, int);
  int (*decode_ptr_)(DiffuserContext*, uint8_t*);
  void (*delete_ptr_)(DiffuserContext*);

  int show_every_n_iteration_;
  bool emit_empty_packet_;
};

absl::Status StableDiffusionIterateCalculator::UpdateContract(
    CalculatorContract* cc) {
  return absl::OkStatus();
}

absl::Status StableDiffusionIterateCalculator::Open(CalculatorContext* cc) {
  StableDiffusionIterateCalculatorOptions options;
  if (kOptionsIn(cc).IsEmpty()) {
    options = cc->Options<StableDiffusionIterateCalculatorOptions>();
  } else {
    options = kOptionsIn(cc).Get();
  }
  show_every_n_iteration_ = options.show_every_n_iteration();
  emit_empty_packet_ = options.emit_empty_packet();

  MP_RETURN_IF_ERROR(LoadDiffuser());

  DiffuserConfig config;
  config.model_type = ToDiffuserModelType(options.model_type());
  if (config.model_type == kDiffuserModelTypeTigo) {
    config.run_unet_with_masked_image = 1;
  } else {
    config.run_unet_with_masked_image = 0;
  }
  if (options.file_folder().empty()) {
    std::strcpy(config.model_dir, "bins/");  // NOLINT
  } else {
    std::string file_folder = options.file_folder();
    if (!file_folder.empty() && file_folder.back() != '/') {
      file_folder.push_back('/');
    }
    std::strcpy(config.model_dir, file_folder.c_str());  // NOLINT
  }
  MP_RETURN_IF_ERROR(mediapipe::file::Exists(config.model_dir))
      << config.model_dir;
  RET_CHECK(options.lora_file_folder().empty() ||
            options.lora_weights_layer_mapping().empty())
      << "Can't set both lora_file_folder and lora_weights_layer_mapping.";
  std::strcpy(config.lora_dir, options.lora_file_folder().c_str());  // NOLINT
  std::map<std::string, const char*> lora_weights_layer_mapping;
  for (auto& layer_name_and_weights : options.lora_weights_layer_mapping()) {
    lora_weights_layer_mapping[layer_name_and_weights.first] =
        (char*)layer_name_and_weights.second;
  }
  config.lora_weights_layer_mapping = !lora_weights_layer_mapping.empty()
                                          ? &lora_weights_layer_mapping
                                          : nullptr;
  config.lora_rank = options.lora_rank();
  config.seed = options.base_seed();
  config.image_width = options.output_image_width();
  config.image_height = options.output_image_height();
  config.run_unet_with_plugins = kPlugInTensorsIn(cc).IsConnected();
  config.env_options = {
      .priority_hint = ToDiffuserPriorityHint(options.cl_priority_hint()),
      .performance_hint = kDiffuserPerformanceHintHigh,
  };
  RET_CHECK(options.plugins_strength() >= 0.0f ||
            options.plugins_strength() <= 1.0f)
      << "The value of plugins_strength must be in the range of [0, 1].";
  context_ = DiffuserCreate(&config);
  RET_CHECK(context_);
  return absl::OkStatus();
}

absl::Status StableDiffusionIterateCalculator::Process(CalculatorContext* cc) {
  const auto& options =
      cc->Options().GetExtension(StableDiffusionIterateCalculatorOptions::ext);
  const std::string& prompt = *kPromptIn(cc);
  const int steps = *kStepsIn(cc);
  const int rand_seed = !kRandSeedIn(cc).IsEmpty() ? std::abs(*kRandSeedIn(cc))
                                                   : options.base_seed();
  float plugins_strength = options.plugins_strength();
  if (kPluginStrengthIn(cc).IsConnected() && !kPluginStrengthIn(cc).IsEmpty()) {
    plugins_strength = kPluginStrengthIn(cc).Get();
    RET_CHECK(plugins_strength >= 0.0f || plugins_strength <= 1.0f)
        << "The value of plugins_strength must be in the range of [0, 1].";
  }

  if (kIterationIn(cc).IsEmpty()) {
    const auto plugin_tensors = GetPluginTensors(cc);
    RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed, plugins_strength,
                            &plugin_tensors));
    for (int i = 0; i < steps; i++) RET_CHECK(DiffuserIterate(steps, i));
    ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
                         options.output_image_height());
    RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));
    kImageOut(cc).Send(std::move(image_out));
  } else {
    const int iteration = *kIterationIn(cc);
    RET_CHECK_LT(iteration, steps);

    // Extract text embedding on first iteration.
    if (iteration == 0) {
      const auto plugin_tensors = GetPluginTensors(cc);
      RET_CHECK(DiffuserReset(prompt.c_str(), steps, rand_seed,
                              plugins_strength, &plugin_tensors));
    }

    RET_CHECK(DiffuserIterate(steps, iteration));

    bool force_show_result = kShowResultIn(cc).IsConnected() &&
                             !kShowResultIn(cc).IsEmpty() &&
                             kShowResultIn(cc).Get();
    bool show_result = force_show_result ||
                       (iteration + 1) % show_every_n_iteration_ == 0 ||
                       iteration == steps - 1;
    // Decode the output and send out the image for visualization.
    if (show_result) {
      ImageFrame image_out(ImageFormat::SRGB, options.output_image_width(),
                           options.output_image_height());
      RET_CHECK(DiffuserDecode(image_out.MutablePixelData()));
      kImageOut(cc).Send(std::move(image_out));
    } else if (emit_empty_packet_) {
      kImageOut(cc).Send(Packet<mediapipe::ImageFrame>());
    }
  }
  return absl::OkStatus();
}

MEDIAPIPE_REGISTER_NODE(StableDiffusionIterateCalculator);

}  // namespace api2
}  // namespace mediapipe
